#!/usr/bin/env python
# coding=utf-8
"""
Build a cropped dataset from COCO 2017.
This script reads images from local COCO 2017 dataset folders,
creates random crops from them, and saves them to a new directory.
"""

import argparse
import concurrent.futures
import os
import random
from typing import List, Optional, Tuple

import numpy as np
from PIL import Image
from tqdm import tqdm

def get_image_paths(dir_path: str) -> List[str]:
    """Return the sorted paths of ``.jpg`` entries (case-insensitive) in *dir_path*.

    A path that is missing or is not a directory yields an empty list instead
    of raising.
    """
    if os.path.isdir(dir_path):
        return sorted(
            os.path.join(dir_path, name)
            for name in os.listdir(dir_path)
            if name.lower().endswith('.jpg')
        )
    return []

def random_crop(
    image_path: str,
    crop_size: int,
    rng: Optional[random.Random] = None,
) -> Optional[np.ndarray]:
    """Return a random ``crop_size`` x ``crop_size`` RGB crop of an image.

    Images smaller than ``crop_size`` in either dimension are first upscaled
    (LANCZOS) so that their shorter side is about 5% larger than ``crop_size``.

    Args:
        image_path: Path of the source image.
        crop_size: Side length, in pixels, of the square crop.
        rng: Random generator used to pick the crop offset. When omitted, the
            module-level ``random`` state is used. Pass a dedicated instance
            when calling from several threads so the result does not depend on
            scheduling order.

    Returns:
        The crop as an ``(crop_size, crop_size, 3)`` uint8 array, or ``None``
        if the image could not be opened or decoded.
    """
    randint = random.randint if rng is None else rng.randint

    try:
        # The context manager releases the file handle right away. convert()
        # loads the pixel data, so decode errors are also caught here.
        with Image.open(image_path) as src:
            image = src.convert('RGB')
    except (IOError, OSError) as e:
        print(f"Warning: Could not open image {image_path}, skipping. Error: {e}")
        return None

    width, height = image.size

    # Most COCO images are large enough. Upscale the rare small ones instead
    # of skipping them, with a 5% margin so the crop always fits.
    if width < crop_size or height < crop_size:
        scale = max(crop_size / width, crop_size / height) * 1.05
        new_width, new_height = int(width * scale), int(height * scale)
        image = image.resize((new_width, new_height), Image.LANCZOS)
        width, height = image.size

    left = randint(0, width - crop_size)
    top = randint(0, height - crop_size)

    crop = image.crop((left, top, left + crop_size, top + crop_size))
    return np.array(crop)

def process_image(args: Tuple[str, str, int, int]) -> bool:
    """Create one random crop and save it as a PNG.

    Args:
        args: ``(source_image_path, save_path, crop_size, seed)``. ``seed``
            initialises a task-private ``random.Random``, so the crop offset is
            reproducible no matter which worker thread runs the task.

    Returns:
        True if the crop was written. False if the source image could not be
        read or the PNG could not be saved.
    """
    source_image_path, save_path, crop_size, seed = args

    crop_array = random_crop(source_image_path, crop_size, rng=random.Random(seed))
    if crop_array is None:
        return False

    try:
        parent = os.path.dirname(save_path)
        if parent:  # os.makedirs('') raises, so skip it for bare filenames
            os.makedirs(parent, exist_ok=True)
        Image.fromarray(crop_array).save(save_path, 'PNG')
    except Exception as e:  # best effort: report and let the other tasks continue
        print(f"Error saving crop to {save_path}: {e}")
        return False
    return True


def create_crops_for_split(
    source_dir: str,
    output_dir: str,
    crop_size: int,
    num_crops: int,
    split_name: str
) -> None:
    """Write ``num_crops`` random crops from ``source_dir`` into ``output_dir``.

    Source images are sampled with replacement. Output files are named
    ``00000000.png``, ``00000001.png``, and so on. If no ``.jpg`` files are
    found, the split is skipped with a warning.

    Args:
        source_dir: Folder holding the split's source ``.jpg`` images.
        output_dir: Folder to write the PNG crops to (created if missing).
        crop_size: Side length of each square crop, in pixels.
        num_crops: Number of crops to generate.
        split_name: Label used in log messages and the progress bar.
    """
    print(f"\nProcessing split: {split_name}")
    print(f"Source: {source_dir}")
    print(f"Destination: {output_dir}")

    os.makedirs(output_dir, exist_ok=True)

    image_paths = get_image_paths(source_dir)
    if not image_paths:
        print(f"Warning: No images found in {source_dir}. Skipping '{split_name}' split.")
        return

    print(f"Found {len(image_paths)} source images for the '{split_name}' split.")

    # Make every random decision here, on the calling thread and in a fixed
    # order. The output then depends only on the global seed, not on thread
    # scheduling.
    tasks = []
    for i in range(num_crops):
        source_img_path = random.choice(image_paths)
        crop_seed = random.getrandbits(64)
        save_path = os.path.join(output_dir, f"{i:08d}.png")
        tasks.append((source_img_path, save_path, crop_size, crop_seed))

    # Work is spread across a thread pool. Each task owns its own RNG, so
    # sharing state across threads is not a concern.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = list(tqdm(
            executor.map(process_image, tasks),
            total=len(tasks),
            desc=f"Creating {split_name} crops"
        ))

    failures = len(results) - sum(results)
    if failures:
        print(f"Warning: {failures} of {len(tasks)} '{split_name}' crops failed.")

def main():
    """Parse the command line and generate crops for each configured COCO split."""
    parser = argparse.ArgumentParser(description="Build a cropped dataset from COCO 2017")
    # One (flag, type, default, help) row per command-line option, in order.
    options = (
        ("--source_dir", str, "", "Base source directory of COCO 2017 dataset"),
        ("--output_dir", str, "", "Base output directory for cropped images"),
        ("--crop_size", int, 256, "Size of the crops (n x n)"),
        ("--train_crops", int, 10000, "Number of training crops to generate"),
        ("--val_crops", int, 5000, "Number of validation crops to generate"),
        ("--test_crops", int, 5000, "Number of test crops to generate"),
        ("--seed", int, 42, "Random seed for reproducibility"),
    )
    for flag, value_type, default, help_text in options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    opts = parser.parse_args()

    # Seed both global RNGs with the same value so runs can be repeated.
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(opts.seed)

    print("Starting COCO dataset cropping process...")
    print(f"Source directory: {opts.source_dir}")
    print(f"Output directory: {opts.output_dir}")

    # Each entry: (COCO source subfolder, output subfolder / split name, crop count).
    splits = [
        ('train2017', 'train', opts.train_crops),
        ('val2017', 'val', opts.val_crops),
        ('test2017', 'test', opts.test_crops),
    ]

    for coco_folder, split, count in splits:
        if count > 0:
            create_crops_for_split(
                source_dir=os.path.join(opts.source_dir, coco_folder),
                output_dir=os.path.join(opts.output_dir, split),
                crop_size=opts.crop_size,
                num_crops=count,
                split_name=split,
            )
        else:
            print(f"Skipping '{split}' as number of crops is set to 0.")

    print(f"\nDone! Cropped dataset saved in: {opts.output_dir}")

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()